{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from transformers import pipeline, AutoTokenizer, GenerationConfig, PhiForCausalLM\n",
    "from langchain.document_loaders import TextLoader\n",
    "\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "from langchain.vectorstores import Chroma\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载和分割本地知识文档\n",
    "这里以2024年1月11号发射的[快舟一号甲](https://baike.baidu.com/item/快舟一号甲)的百科词条为例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载本地词向量模型，使用的是 https://huggingface.co/BAAI/bge-base-zh\n",
    "# model_name = \"./data/BAAI_bge-base-zh\"\n",
    "model_name = \"BAAI/bge-base-zh\"\n",
    "model_kwargs = {'device': 'cuda'}\n",
    "encode_kwargs = {'normalize_embeddings': True}\n",
    "\n",
    "embedding = HuggingFaceBgeEmbeddings(\n",
    "                model_name=model_name,\n",
    "                model_kwargs=model_kwargs,\n",
    "                encode_kwargs=encode_kwargs,\n",
    "                query_instruction=\"为文本生成向量表示用于文本检索\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Created a chunk of size 261, which is longer than the specified 96\n",
      "Created a chunk of size 963, which is longer than the specified 96\n",
      "Created a chunk of size 551, which is longer than the specified 96\n",
      "Created a chunk of size 499, which is longer than the specified 96\n",
      "Created a chunk of size 104, which is longer than the specified 96\n"
     ]
    }
   ],
   "source": [
    "doc_db_save_dir = './model_save/vector'\n",
    "\n",
    "if not os.path.exists(doc_db_save_dir):\n",
    "\n",
    "    # 1. 从文件读取本地数据集\n",
    "    loader = TextLoader(\"./data/快舟一号甲.txt\")\n",
    "    documents = loader.load()\n",
    "\n",
    "    # 2. 拆分文档\n",
    "    text_splitter = CharacterTextSplitter(chunk_size=96, chunk_overlap=8)\n",
    "    splited_documents = text_splitter.split_documents(documents)\n",
    "\n",
    "    # 3. 向量化并保存到本地目录\n",
    "\n",
    "    db = Chroma.from_documents(splited_documents, embedding, persist_directory=doc_db_save_dir)\n",
    "    db.persist()\n",
    "else:\n",
    "    db = Chroma(persist_directory=doc_db_save_dir,  embedding_function=embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载对话模型并构造对话prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_id = './model_save/dpo/'\n",
    "\n",
    "model = PhiForCausalLM.from_pretrained(model_id).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "phi_pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"快舟一号甲的近地轨道运载能力是多少？\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "请根据以下给出的背景知识回答问题，对于不知道的信息，直接回答“未找到相关答案”。\n",
      "以下为为背景知识：\n",
      "0. 快舟一号甲:\n",
      "快舟一号甲（英文：Kuaizhou-1A，简称：KZ-1A），是由中国航天科工火箭技术有限公司研制的三级固体运载火箭。\n",
      "快舟一号甲运载火箭全长约20米，起飞质量约30吨，整流罩最大直径1.4米，太阳同步圆轨道的运载能力为200千克/700千米，近地轨道运载能力为300千克。火箭采用车载机动发射方式，主要面向微小卫星发射和组网，具备一箭多星发射能力。\n",
      "2024年1月11日11时52分，中国在酒泉卫星发射中心使用快舟一号甲运载火箭，成功将天行一号02星发射升空，卫星顺利进入预定轨道，发射任务获得圆满成功。\n",
      "以下为问题：\n",
      "快舟一号甲的近地轨道运载能力是多少？\n"
     ]
    }
   ],
   "source": [
    "# 构造prompt\n",
    "template = \"请根据以下给出的背景知识回答问题，对于不知道的信息，直接回答“未找到相关答案”。\\n以下为为背景知识：\\n\"\n",
    "\n",
    "similar_docs = db.similarity_search(question, k = 1)\n",
    "for i, doc in enumerate(similar_docs):\n",
    "    template += f\"{i}. {doc.page_content}\"\n",
    "\n",
    "template += f'\\n以下为问题：\\n{question}'\n",
    "print(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "快艇一号甲的近地轨道运载能力为300千克。\n"
     ]
    }
   ],
   "source": [
    "prompt = f\"##提问:\\n{template}\\n##回答:\\n\"\n",
    "outputs = phi_pipe(prompt, num_return_sequences=1, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "print(outputs[0]['generated_text'][len(prompt): ])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
